# -*- coding:utf-8 -*-

import tensorflow as tf

"""
优化，损失函数相关
"""


def get_optimizer(optimizer_name, learning_rate, decay=0.9, momentum=0.9):
    """
    Build a TF1 training optimizer by name.

    :param optimizer_name: one of 'momentum', 'rmsprop', 'adam', 'sgd'
    :param learning_rate: learning rate (scalar or tensor)
    :param decay: RMSProp decay factor (only consumed by 'rmsprop')
    :param momentum: momentum term (consumed by 'momentum' and 'rmsprop')
    :return: a tf.train.Optimizer instance
    :raises ValueError: if optimizer_name is not a supported choice
    """
    # Lazy builders so only the requested optimizer object is constructed.
    builders = {
        'momentum': lambda: tf.train.MomentumOptimizer(learning_rate, momentum=momentum),
        'rmsprop': lambda: tf.train.RMSPropOptimizer(learning_rate, decay=decay, momentum=momentum),
        'adam': lambda: tf.train.AdamOptimizer(learning_rate),
        'sgd': lambda: tf.train.GradientDescentOptimizer(learning_rate),
    }
    build = builders.get(optimizer_name)
    if build is None:
        raise ValueError('不支持该类型优化器')
    return build()


def get_learning_rate(global_step, learning_rate):
    """
    Build a learning-rate tensor from a schedule config.

    :param global_step: global step tensor/variable driving the decay schedule
    :param learning_rate: config dict with key 'name' ('fixed' or 'exp_decay');
        'fixed' reads 'value'; 'exp_decay' additionally reads
        'decay_step' and 'decay_rate'
    :return: a scalar learning-rate tensor
    :raises ValueError: if the schedule name is not supported
    """
    schedule = learning_rate['name']
    if schedule == 'fixed':
        # Constant learning rate for the whole run.
        return tf.constant(learning_rate['value'], dtype=tf.float32)
    if schedule == 'exp_decay':
        # Staircase exponential decay: value * decay_rate ^ (step // decay_step)
        return tf.train.exponential_decay(
            learning_rate['value'],
            global_step,
            learning_rate['decay_step'],
            learning_rate['decay_rate'],
            staircase=True)
    raise ValueError('不支持该类型学习率')
